{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from catalyst import dl\n",
    "from src.runners import DistilMLMRunner\n",
    "from src.models import DistilbertStudentModel, BertForMLM\n",
    "from catalyst.core import MetricAggregationCallback\n",
    "from torch.utils.data import DataLoader\n",
    "from src.callbacks import (\n",
    "    CosineLossCallback,\n",
    "    KLDivLossCallback,\n",
    "    MaskedLanguageModelCallback,\n",
    "    MSELossCallback,\n",
    "    PerplexityMetricCallbackDistillation,\n",
    ")\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory expected to hold the train.csv and valid.csv splits.\n",
    "PATH_TO_YOUR_DATASET = \"./data\"\n",
    "\n",
    "train_csv_path = f\"{PATH_TO_YOUR_DATASET}/train.csv\"\n",
    "valid_csv_path = f\"{PATH_TO_YOUR_DATASET}/valid.csv\"\n",
    "\n",
    "# Read each split into its own DataFrame.\n",
    "train_df = pd.read_csv(train_csv_path)\n",
    "valid_df = pd.read_csv(valid_csv_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from catalyst.contrib.data.nlp import LanguageModelingDataset\n",
    "from transformers import AutoTokenizer\n",
    "from transformers.data.data_collator import DataCollatorForLanguageModeling\n",
    "\n",
    "# Single source of truth for the pretrained checkpoint: used for the\n",
    "# tokenizer here and for the teacher/student models in the next cell.\n",
    "model_name = \"bert-base-uncased\"\n",
    "# Batch size shared by the train and valid loaders.\n",
    "BATCH_SIZE = 2\n",
    "\n",
    "# Reuse `model_name` so the tokenizer cannot drift from the models.\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "\n",
    "# Both datasets are built from the \"text\" column of their DataFrame.\n",
    "train_dataset = LanguageModelingDataset(train_df[\"text\"], tokenizer)\n",
    "valid_dataset = LanguageModelingDataset(valid_df[\"text\"], tokenizer)\n",
    "\n",
    "# NOTE(review): `collate_batch` only exists in older transformers releases;\n",
    "# presumably it pads and applies MLM masking per batch - confirm against the\n",
    "# pinned transformers version.\n",
    "collate_fn = DataCollatorForLanguageModeling(tokenizer).collate_batch\n",
    "train_dataloader = DataLoader(\n",
    "    train_dataset, collate_fn=collate_fn, batch_size=BATCH_SIZE\n",
    ")\n",
    "valid_dataloader = DataLoader(\n",
    "    valid_dataset, collate_fn=collate_fn, batch_size=BATCH_SIZE\n",
    ")\n",
    "# Loader names (\"train\" / \"valid\") are the keys handed to runner.train.\n",
    "loaders = {\"train\": train_dataloader, \"valid\": valid_dataloader}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Teacher and student are both constructed from the same `model_name`.\n",
    "teacher = BertForMLM(model_name)\n",
    "student = DistilbertStudentModel(model_name)\n",
    "\n",
    "# Bundle the pair into one module, keyed by role.\n",
    "networks = {\"teacher\": teacher, \"student\": student}\n",
    "model = torch.nn.ModuleDict(networks)\n",
    "\n",
    "# Every loss term contributes with weight 1.0 to the aggregated \"loss\".\n",
    "loss_weights = {\n",
    "    \"cosine_loss\": 1.0,\n",
    "    \"masked_lm_loss\": 1.0,\n",
    "    \"kl_div_loss\": 1.0,\n",
    "    \"mse_loss\": 1.0\n",
    "}\n",
    "\n",
    "# Individual loss callbacks, their weighted sum, the optimizer step, and\n",
    "# the perplexity metric - listed in the same order as before.\n",
    "callbacks = {\n",
    "    \"masked_lm_loss\": MaskedLanguageModelCallback(),\n",
    "    \"mse_loss\": MSELossCallback(),\n",
    "    \"cosine_loss\": CosineLossCallback(),\n",
    "    \"kl_div_loss\": KLDivLossCallback(),\n",
    "    \"loss\": MetricAggregationCallback(\n",
    "        prefix=\"loss\", mode=\"weighted_sum\", metrics=loss_weights\n",
    "    ),\n",
    "    \"optimizer\": dl.OptimizerCallback(),\n",
    "    \"perplexity\": PerplexityMetricCallbackDistillation()\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU when torch can see one; otherwise fall back to the CPU so\n",
    "# the notebook still runs end to end on machines without CUDA.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "runner = DistilMLMRunner(device=device)\n",
    "\n",
    "# NOTE(review): model.parameters() covers both teacher and student -\n",
    "# confirm whether the teacher is meant to be updated or kept frozen.\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)\n",
    "\n",
    "# No `criterion` is passed: that name is never defined in this notebook\n",
    "# (it raised NameError), and the loss terms come from the callbacks dict.\n",
    "runner.train(\n",
    "    model=model,\n",
    "    optimizer=optimizer,\n",
    "    loaders=loaders,\n",
    "    verbose=True,\n",
    "    # NOTE(review): presumably Catalyst's quick sanity-run mode; set to\n",
    "    # False for a full training run - confirm.\n",
    "    check=True,\n",
    "    callbacks=callbacks,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:.conda-dbert]",
   "language": "python",
   "name": "conda-env-.conda-dbert-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}